## -*- coding: utf-8 -*-
#Load Relevant Packages
import numpy as np
import io, json, sys, os
from collections import Counter, defaultdict
import nltk
import unicodedata, regex
import string, pickle
import pandas as pd

# All relative paths below ('data/...', 'figures/...', 'gold_annotations/...')
# are resolved against this folder.
# NOTE(review): hard-coded, machine-specific absolute path - adjust it to the
# local copy of the replication folder before running.
os.chdir('C:/Users/Max/Dropbox/Research/Work_Jorge_Miguel/LSQ Final Submission/LSQ_Replication')
# Lookup table of the ASCII punctuation characters ("pontuacao" = punctuation).
pontuacao = Counter(string.punctuation)

#Load machine learning functions from sklearn
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction.text import CountVectorizer

from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC
from sklearn.svm import SVC
from sklearn import metrics
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier

# Unigram CountVectorizer with no stop-word list; in this file it is only used
# to obtain its tokenizer (no custom token pattern is supplied, so the
# CountVectorizer default applies).
vectorizer = CountVectorizer(ngram_range=(1,1),stop_words = None)
# Callable that splits a raw string into tokens; shared by every
# pre-processing function below so speeches and bills are tokenized alike.
word_tokenizer = vectorizer.build_tokenizer()

#%%
# Portuguese class labels (as they appear in the bill data and in the
# classification reports) mapped to the English names used in the LaTeX
# tables. The last key is the reformatted summary row of a report.
portugese_to_english = dict([
    ('AGRICULTURA', 'Agriculture and Fisheries'),
    ('CONSTITUICAO', 'Constitution, Civil Rights and Liberties'),
    ('CULTURA', 'Culture'),
    ('DEFESA', 'National Defense'),
    ('ECONOMIA', 'Economy and Public Works'),
    ('EDUCACAO', 'Education and Science'),
    ('ESTRANGEIRO', 'Foreign Affairs'),
    ('ETICA', 'Ethics, Society and Assembly House Keeping'),
    ('EUROPA', 'European Affairs'),
    ('ORCAMENTO', 'Public Administration and Budget'),
    ('ORDENAMENTO', 'Environment, Territory and Local Government'),
    ('SAUDE', 'Health'),
    ('TRABALHO', 'Labor and Social Security'),
    ('avg & / & total', 'Overall'),
])

#%% Define Functions.
def build_stopwords(include = ('numerals', 'names', 'time')):
    """Build the stopword lookup used to filter tokens.

    Starts from NLTK's Portuguese stopword list and optionally adds three
    extra groups of words.

    Parameters
    ----------
    include : iterable of str, optional
        Extra groups to add; any of ``'numerals'`` (Roman numerals i-xxi),
        ``'names'`` (the entries of ``data/nomes_clean.csv``) and ``'time'``
        (weekday and month words). Membership is tested with ``in``, so an
        empty sequence or empty string adds nothing.

    Returns
    -------
    collections.Counter
        Word -> positive count. A missing word looks up as 0, which is what
        the ``not stopwords[token]`` filters in this file rely on.
    """
    numerals = [u'i',u'ii',u'iii',u'iv',u'v',u'vi',u'vii',u'viii',u'ix',u'x',u'xi',u'xii',u'xiii',u'xiv',u'xv',u'xvi',u'xvii',u'xviii',u'xix',u'xx',u'xxi']
    weekdays_months = [u'segunda',u'terceira',u'quarta',u'quinta',u'sexta',u'sábado',u'domingo',
                       u'janeiro',u'fevereiro',u'março',u'abril',u'maio',u'junho',u'julho',u'agosto',u'setembro',
                       u'outubro',u'novembro',u'dezembro']
    stopwords = Counter(nltk.corpus.stopwords.words('portuguese'))
    if 'numerals' in include:
        stopwords.update(numerals)
    if 'names' in include:
        # The names file is only opened when names are actually requested.
        with io.open('data/nomes_clean.csv','r',encoding = 'utf8') as f_in:
            for line in f_in:
                name = line.strip(u'\n').lower().strip()
                # NFD-decompose, then drop every non-ASCII code point, which
                # removes the accents from the name.
                # NOTE(review): the weekday/month words above keep their
                # accents while names lose theirs - confirm this is intended.
                stopwords[unicodedata.normalize('NFD', name).encode('ascii','ignore').decode('utf8')] += 1
    if 'time' in include:
        stopwords.update(weekdays_months)
    return stopwords

#%%First Pre-Processing Function

def preprocess_speech(threshold = 50, stopwords = None, remove_digits = True):
    """Load and clean the speeches of legislatures 8-12.

    Reads ``data/dar_no_government_ids_<leg>.json`` for each legislature,
    tokenizes every speech with the module-level ``word_tokenizer``,
    lower-cases the tokens and removes stopwords (and optionally digit-only
    tokens).

    Parameters
    ----------
    threshold : int, optional
        A speech is kept only if it has MORE than this many tokens, counted
        before stopword removal.
    stopwords : mapping, optional
        Lookup where ``stopwords[token]`` is truthy for words to drop.
        ``None`` (default) builds the full list with ``build_stopwords()``;
        this happens at call time rather than once when the module loads.
    remove_digits : bool, optional
        Also drop tokens for which ``str.isdigit()`` is true.

    Returns
    -------
    list of str
        One space-joined string of surviving tokens per kept speech.
    """
    if stopwords is None:
        stopwords = build_stopwords()
    docs_AR = []
    for leg in [8,9,10,11,12]:
        with io.open('data/dar_no_government_ids_' + str(leg) + '.json', 'r', encoding = 'utf8') as f:
            main = json.load(f)
        # Each value is iterated as (text, <second field>) pairs; only the
        # text is used here.
        for speeches in main.values():
            for doc, _ in speeches:
                clean = [token.lower().replace(u'\n',u'') for token in word_tokenizer(doc)]
                if len(clean) <= threshold:
                    continue
                kept = [token for token in clean
                        if not stopwords[token] and not (remove_digits and token.isdigit())]
                docs_AR.append(' '.join(kept))
    return docs_AR
#%%
def preprocess_bills(vectorizer, stopwords = None, remove_digits = True):
    """Load, clean and vectorize the labelled bills.

    Reads ``data/bills_with_classes_v2.json``, keeps the bills of
    legislatures VIII-XII, groups their ``'texto'`` by ``'classe'`` and cleans
    each text section the same way speeches are cleaned (lower-cased tokens,
    stopwords and optionally digit-only tokens removed).

    Parameters
    ----------
    vectorizer : fitted vectorizer
        Object whose ``transform`` turns a list of strings into a feature
        matrix.
    stopwords : mapping, optional
        Lookup where ``stopwords[token]`` is truthy for words to drop.
        ``None`` (default) builds the full list with ``build_stopwords()``
        at call time.
    remove_digits : bool, optional
        Also drop tokens for which ``str.isdigit()`` is true.

    Returns
    -------
    tuple
        ``(X_technical, X_motivations, X_joint)``: transformed matrices;
        ``(target_technical, target_motivations, target_joint)``: matching
        lists of integer class ids;
        ``id2class``: dict mapping each integer id back to its class label.
        Ids follow the order in which classes first appear in the file.
    """
    if stopwords is None:
        stopwords = build_stopwords()
    with io.open('data/bills_with_classes_v2.json', 'r', encoding = 'utf8') as f:
        main_bills = json.load(f)
    legis = ['VIII', 'IX', 'X', 'XI', 'XII']
    main = defaultdict(list)
    for initiative in main_bills:
        if initiative['legislatura'] in legis:
            main[initiative['classe']].append(initiative['texto'])
    id2class = {}
    target_motivations, target_technical, target_joint = [], [], []
    motivations, technical, joint_output = [], [], []
    for i, (key, bills) in enumerate(main.items()):
        # Convert the class label to a numeric code.
        id2class[i] = key
        for both in bills:
            # presumably `both` is the list of text sections of one bill
            # (motivation first, then technical text) - verify against data.
            counter = 0
            joint = ''
            for doc in both:
                doc = doc.strip()
                if not doc:
                    continue
                clean = [token.lower() for token in word_tokenizer(doc)]
                temp = ' '.join([token for token in clean
                                 if not stopwords[token] and not (remove_digits and token.isdigit())])
                # The first NON-EMPTY section counts as the motivation and
                # every later one as technical text.
                # NOTE(review): if the first section of a bill is blank, its
                # next section is therefore filed under motivations.
                if counter == 0:
                    motivations.append(temp)
                    target_motivations.append(i)
                else:
                    technical.append(temp)
                    target_technical.append(i)
                counter += 1
                joint += ' ' + temp
            # Appended for every bill, even when all its sections were empty.
            joint_output.append(joint)
            target_joint.append(i)
    X_technical_tfidf = vectorizer.transform(technical)
    X_motivations_tfidf = vectorizer.transform(motivations)
    X_joint_tfidf = vectorizer.transform(joint_output)
    return (X_technical_tfidf, X_motivations_tfidf, X_joint_tfidf), (target_technical, target_motivations, target_joint), id2class

#%%
def classify_speeches(classifier, vectorizer, threshold = 50, remove_digits = True, getdf = False, stopwords = None):
    """Predict a class for every sufficiently long speech of legislatures 8-12.

    Speeches are read from ``data/dar_no_government_ids_<leg>.json`` and
    cleaned exactly as in ``preprocess_speech`` before being vectorized and
    passed to each classifier.

    Parameters
    ----------
    classifier : iterable of fitted estimators
        Each must expose ``predict`` (and ``decision_function`` when
        ``getdf`` is set).
    vectorizer : fitted vectorizer
        Object whose ``transform`` turns a list of strings into a sparse
        matrix (it is densified with ``toarray()`` before prediction).
    threshold : int, optional
        A speech is kept only if it has MORE than this many tokens, counted
        before stopword removal.
    remove_digits : bool, optional
        Also drop tokens for which ``str.isdigit()`` is true.
    getdf : bool, optional
        If true, also return the horizontally stacked decision-function
        values of all classifiers for each speech.
    stopwords : mapping, optional
        Lookup where ``stopwords[token]`` is truthy for words to drop.
        ``None`` (default) builds the full list with ``build_stopwords()``
        at call time.

    Returns
    -------
    list of list
        One ``[mp_id, legislature, date, n_tokens, predictions, scores]`` row
        per kept speech, grouped by MP and then legislature. ``n_tokens`` is
        the token count after filtering, ``predictions`` holds one entry per
        classifier, and ``scores`` is ``None`` unless ``getdf`` is set.
    """
    if stopwords is None:
        stopwords = build_stopwords()
    main_docs_class = defaultdict(lambda: defaultdict(list))
    for leg in [8,9,10,11,12]:
        print('.', end = '')
        with io.open('data/dar_no_government_ids_' + str(leg) + '.json', 'r', encoding = 'utf8') as f:
            main = json.load(f)
        datas = []
        docs = []
        mp_id_list = []
        for mp_id, speeches in main.items():
            for doc, date in speeches:
                clean = [token.lower().replace(u'\n',u'') for token in word_tokenizer(doc)]
                if len(clean) <= threshold:
                    continue
                # One filtering path for both settings. Previously, with
                # remove_digits off, the tokens were joined before being
                # counted and then joined a second time, which stored a
                # character count and a text with every character separated
                # by a space.
                kept = [token for token in clean
                        if not stopwords[token] and not (remove_digits and token.isdigit())]
                docs.append(' '.join(kept))
                datas.append([date, len(kept)])
                mp_id_list.append(mp_id)
        if docs:
            doc_tfidf = vectorizer.transform(docs).toarray()
            # Shape (n_speeches, n_classifiers): one column per classifier.
            stacked_predictions = np.vstack([cl.predict(doc_tfidf) for cl in classifier]).T
            stacked_df = None
            if getdf:
                stacked_df = np.hstack([cl.decision_function(doc_tfidf) for cl in classifier])
            for i, cl in enumerate(stacked_predictions):
                scores = stacked_df[i] if getdf else None
                main_docs_class[mp_id_list[i]][leg].append((datas[i], cl, scores))
    # Flatten the nested mp_id -> legislature -> speeches mapping into rows.
    output = []
    for mp_id, by_leg in main_docs_class.items():
        for leg, entries in by_leg.items():
            for date, cl, ar in entries:
                output.append([mp_id, leg, date[0], date[1], cl, ar])
    return output
#%%SVM Validation
validate_svm = True #Set this to False to skip cross-validation and move directly to classification.
if validate_svm:
    print('Validating SVM Model Specifications')
    # Each entry is one pre-processing specification; the first is the
    # baseline and the others each change one or two settings of it.
    validation_combinations = [
            {'ngrams': 'bigrams', 'min_df': 0.005, 'max_df': 0.9, 'stopwords': 'full', 'tfidf':True},
            {'ngrams': 'unigrams', 'min_df': 0.005, 'max_df': 0.9, 'stopwords': 'full', 'tfidf':True}, #Unigrams
            {'ngrams': 'bigrams', 'min_df': 0.0, 'max_df': 1.0, 'stopwords': 'full', 'tfidf':True}, #LimitedFiltering
            {'ngrams': 'bigrams', 'min_df': 0.005, 'max_df': 1.0, 'stopwords': 'full', 'tfidf':True}, #Limited Filtering
            {'ngrams': 'bigrams', 'min_df': 0.0, 'max_df': 0.9, 'stopwords': 'full', 'tfidf':True},  #Limited Filtering
            {'ngrams': 'bigrams', 'min_df': 0.005, 'max_df': 0.9, 'stopwords': '', 'tfidf':True}, #Limited Stopwords
            {'ngrams': 'bigrams', 'min_df': 0.005, 'max_df': 0.9, 'stopwords': 'full', 'tfidf':False}#No tf-idf
    ]
    data_type = ['technical', 'motivations', 'joint']
    C_vals = [0.1, 0.5, 1, 2.5, 5]
    f1_frames = []
    for counter, v in enumerate(validation_combinations): #Loop Through Validation Combinations
        print(v)
        nr = (1,2) if v['ngrams'] == 'bigrams' else (1,1)
        if v['stopwords'] == 'full':
            sw = build_stopwords()
        else:
            # An empty string matches none of the optional groups, leaving
            # only the NLTK list.
            sw = build_stopwords(include = v['stopwords'])
        speech_docs = preprocess_speech(stopwords = sw) #Get dictionary around speech words.
        speech_vectorizer = TfidfVectorizer(ngram_range=nr,min_df = v['min_df'],max_df = v['max_df'], stop_words = None,use_idf=v['tfidf'],smooth_idf = False)
        speech_vectorizer.fit(speech_docs) #Get fitted tf-idf with correct n-gram vocabulary
        all_tfidf, all_target, bill_id2class = preprocess_bills(vectorizer = speech_vectorizer, stopwords = sw) #Format Bills.
        print('Running SVMs')
        for d in range(3):
            print(data_type[d])
            #Create 80/20 stratified train/test split.
            X_train, X_test, y_train, y_test = train_test_split(all_tfidf[d], all_target[d],
                stratify = all_target[d], test_size=0.2, random_state=1)
            y_test_ = [bill_id2class[x] for x in y_test]
            store_f1s = []
            for C in C_vals:
                clf = LinearSVC(multi_class = 'ovr', C = C)
                clf.fit(X_train, y_train)
                pred = [bill_id2class[x] for x in clf.predict(X_test)]
                # Keep the per-class report of the baseline specification on
                # the joint text at C = 1; it becomes bill_report.tex below.
                if counter == 0 and d == 2 and C == 1:
                    svm_full_report = metrics.classification_report(y_test_, pred)
                # f1_score is taken from the public sklearn.metrics namespace
                # instead of the private metrics.classification module.
                store_f1s.append([metrics.f1_score(y_test_, pred, average='weighted'),
                                  metrics.f1_score(y_test_, pred, average='macro')])
            store_f1s = pd.DataFrame(store_f1s)
            store_f1s.columns = ['f1_score_weighted_avg', 'f1_score_average']
            store_f1s['C'] = C_vals
            store_f1s['n_features'] = X_train.shape[1]
            store_f1s['class'] = counter  # index of the specification
            for setting, value in v.items():
                store_f1s[setting] = value
            store_f1s['type'] = data_type[d]
            f1_frames.append(store_f1s)
        del speech_vectorizer
    # Single concat instead of repeated DataFrame.append.
    all_f1s = pd.concat(f1_frames)
    all_f1s.to_csv('figures/cv_svm.csv', index = False)

    # Turn the plain-text report into LaTeX table rows.
    # NOTE(review): this parsing expects class rows followed by a single
    # 'avg / total' summary row; a report with other summary rows would hit a
    # KeyError in portugese_to_english - confirm against the installed
    # scikit-learn version.
    svm_full_report = regex.sub('[^9]* (?=AGRI)', '', svm_full_report)  # drop the header before the first class row
    svm_full_report = [' & '.join(regex.split(' +', i.strip())) + '\\\\' for i in svm_full_report.split('\n')]

    report_rows = []
    for l in svm_full_report:
        if l.startswith('\\\\'):
            # Blank report line: nothing but the row terminator.
            s_l = ''
        elif l.startswith('avg'):
            s_l = regex.sub('avg & / & total', 'Overall', l)
        else:
            s_l = l.split(' & ')
            s_l[0] = portugese_to_english[s_l[0]]
            s_l = ' & '.join(s_l)
        report_rows.append(s_l)
    svm_full_report = '\n'.join(report_rows)
    svm_full_report = [i.strip() for i in regex.split('Overall', svm_full_report)]
    # '\\hline' is the same string the old '\hline' literal produced, without
    # relying on an invalid escape sequence.
    svm_full_report = svm_full_report[0] + '\\hline\n Overall' + svm_full_report[1]
    with io.open('figures/bill_report.tex', 'w') as f: f.write(svm_full_report)

    # Class counts over the motivation targets (all_target[1]). The count is
    # looked up by class name; pairing labels and counts by position only
    # works if class ids happen to follow the order of portugese_to_english.
    number_of_bills = Counter(all_target[1])
    class2id = {name: idx for idx, name in bill_id2class.items()}
    bill_rows = [english + ' & ' + str(number_of_bills[class2id[portuguese]]) + '\\\\'
                 for portuguese, english in portugese_to_english.items()
                 if portuguese in class2id]
    with io.open('figures/number_of_bills.tex', 'w') as f: f.write('\n'.join(bill_rows))
    print(sum(number_of_bills.values()))
#%%
print('Train Final Model on Full Data')
sw = build_stopwords()
# The n-gram vocabulary and idf weights are learned from the speeches, so
# bills and speeches end up in the same feature space.
speech_docs = preprocess_speech(stopwords = sw)
speech_vectorizer = TfidfVectorizer(
    ngram_range = (1, 2),
    min_df = 0.005,
    max_df = 0.9,
    stop_words = None,
    use_idf = True,
    smooth_idf = False,
)
speech_vectorizer.fit(speech_docs)
# Bills are cleaned and projected onto the speech vocabulary.
all_tfidf, all_target, bill_id2class = preprocess_bills(vectorizer = speech_vectorizer, stopwords = sw)
# Final classifier: one-vs-rest linear SVM trained on the joint text
# (index 2) of every bill.
main_clf = LinearSVC(multi_class = 'ovr', C = 1)
main_clf.fit(all_tfidf[2], all_target[2])
#%% - Predict Speeches
speech_rows = classify_speeches(classifier = [main_clf], vectorizer = speech_vectorizer)
predict_speech = pd.DataFrame(speech_rows)
predict_speech.columns = ['mp_id', 'legislature', 'date', 'length_words', 'class', 'null']
# Each 'class' entry holds one prediction per classifier; only one classifier
# was passed in, so take its prediction.
predict_speech['class'] = [per_clf[0] for per_clf in predict_speech['class']]
predict_speech['class_name'] = [bill_id2class[class_id] for class_id in predict_speech['class']]
predict_speech.to_csv('data/classifed_speeches.csv', index = False, encoding = 'UTF-8')

#%%
print('Examining Annotated Speeches')
#Predicting the Annotated Speeches
with io.open('gold_annotations/gold_and_annotated_speeches.json', 'r', encoding = 'utf8') as f:
    main = json.load(f)

# NOTE(review): the annotated speeches are vectorized as stored, without the
# stopword/digit filtering applied to bills and speeches above - presumably
# the file already holds pre-processed text; confirm.
annotated_speeches = [i['speech'] for i in main]
annotated_speeches = speech_vectorizer.transform(annotated_speeches)
annotated_gold = np.array([i['class_gold'] for i in main])
# Class id -> class name, recovered from the (id, name) pairs stored in
# 'class_pred'.
annotated_dict = {i['class_pred'][0]: i['class_pred'][1] for i in main}
# Id 1 is set by hand - presumably because no stored prediction carries that
# class; verify against the annotation file.
annotated_dict[1] = 'ETICA'
annotated_gold_named = np.array([annotated_dict[i['class_gold']] for i in main])
annotated_baseline = np.array([i['class_pred'][0] for i in main])

# Train on each bill representation (0 = technical, 1 = motivations,
# 2 = joint) for each C and score the predictions against the gold labels.
store_annotation_f1s = []
C_vals = [0.1, 0.5, 1, 2.5, 5]
for d in [0,1,2]:
    for C in C_vals:
        clf = LinearSVC(multi_class = 'ovr',C=C)
        clf.fit(all_tfidf[d], all_target[d])
        speech_pred = [bill_id2class[i] for i in clf.predict(annotated_speeches)]
        if d == 2 and C == 1:
            speech_full_report = metrics.classification_report(annotated_gold_named, speech_pred)
        # f1_score is taken from the public sklearn.metrics namespace instead
        # of the private metrics.classification module.
        store_annotation_f1s.append([d, C,
                                     metrics.f1_score(annotated_gold_named, speech_pred, average = 'macro'),
                                     metrics.f1_score(annotated_gold_named, speech_pred, average = 'weighted')])
store_annotation_f1s = pd.DataFrame(store_annotation_f1s)
store_annotation_f1s.columns = ['data_type', 'C', 'f1_avg', 'f1_weighted']
store_annotation_f1s['data_type'] = np.array(['Technical', 'Motivations', 'All'])[store_annotation_f1s['data_type']]
store_annotation_f1s.to_csv('figures/cv_speeches_svm.csv', index = False)

# Turn the plain-text report into LaTeX table rows.
# NOTE(review): this parsing expects class rows followed by a single
# 'avg / total' summary row; a report with other summary rows would hit a
# KeyError in portugese_to_english - confirm against the installed
# scikit-learn version.
speech_full_report = regex.sub('[^9]* (?=AGRI)', '', speech_full_report)  # drop the header before the first class row
speech_full_report = [' & '.join(regex.split(' +', i.strip())) + '\\\\' for i in speech_full_report.split('\n')]

report_rows = []
for l in speech_full_report:
    if l.startswith('\\\\'):
        # Blank report line: nothing but the row terminator.
        s_l = ''
    elif l.startswith('avg'):
        s_l = regex.sub('avg & / & total', 'Overall', l)
    else:
        s_l = l.split(' & ')
        s_l[0] = portugese_to_english[s_l[0]]
        s_l = ' & '.join(s_l)
    report_rows.append(s_l)
speech_full_report = '\n'.join(report_rows)
speech_full_report = [i.strip() for i in regex.split('Overall', speech_full_report)]
# '\\hline' is the same string the old '\hline' literal produced, without
# relying on an invalid escape sequence.
speech_full_report = speech_full_report[0] + '\\hline\n Overall' + speech_full_report[1]
with io.open('figures/speech_report.tex', 'w') as f: f.write(speech_full_report)
#%%
#Validation accuracy of alternative models on the annotated speeches.
#Only done for the joint bill text (index 2).
#NOTE(review): the random forest and the neural net are fitted without a
#random_state, so their scores can differ between runs.
store_est = []
for d in [2]:
    print(d)
    lasso_cv = LogisticRegressionCV(Cs = np.arange(1,25, step = 1), penalty = 'l1', solver = 'liblinear').fit(all_tfidf[d], all_target[d])
    ridge_cv = LogisticRegressionCV(Cs = np.arange(1,25, step = 1), penalty = 'l2', solver = 'liblinear').fit(all_tfidf[d], all_target[d])
    rf =  RandomForestClassifier(n_estimators = 20).fit(all_tfidf[d], all_target[d])
    nn = MLPClassifier(hidden_layer_sizes=(100,), activation = 'tanh').fit(all_tfidf[d], all_target[d])
    svm =  LinearSVC(C=1).fit(all_tfidf[d], all_target[d])
    # The position in this list is the model code that is mapped to a model
    # name after the loop, so the two orders must stay in sync.
    for model_idx, est in enumerate([svm, lasso_cv, ridge_cv, rf, nn]):
        speech_pred = [bill_id2class[i] for i in est.predict(annotated_speeches)]
        # f1_score is taken from the public sklearn.metrics namespace instead
        # of the private metrics.classification module.
        store_est.append([d, model_idx,
                          metrics.f1_score(annotated_gold_named, speech_pred, average = 'macro'),
                          metrics.f1_score(annotated_gold_named, speech_pred, average = 'weighted')])
store_est = pd.DataFrame(store_est)
store_est.columns = ['data_type', 'model', 'f1_average', 'f1_weighted']
store_est['model'] = np.array(['SVM', 'LASSO', 'Ridge', 'Random Forest', 'Neural Net'])[store_est['model']]
store_est['data_type'] = np.array(['Technical', 'Motivations', 'All'])[store_est['data_type']]
store_est.to_csv('figures/cv_speeches_alternative_models.csv', index = False)
#%%
#Validation accuracy of alternative models on held-out bills (80/20 split).
#Only done for the joint bill text (index 2).
#NOTE(review): the random forest and the neural net are fitted without a
#random_state, so their scores can differ between runs.
store_est = []
for d in [2]:
    print(d)
    X_train, X_test, y_train, y_test = train_test_split(all_tfidf[d], all_target[d],
        stratify = all_target[d], test_size=0.2, random_state=1)
    fmt_y = [bill_id2class[i] for i in y_test]
    lasso_cv = LogisticRegressionCV(Cs = np.arange(1,25, step = 1), penalty = 'l1', solver = 'liblinear').fit(X_train, y_train)
    ridge_cv = LogisticRegressionCV(Cs = np.arange(1,25, step = 1), penalty = 'l2', solver = 'liblinear').fit(X_train, y_train)
    rf =  RandomForestClassifier(n_estimators = 20).fit(X_train, y_train)
    svm = LinearSVC(C = 1, multi_class = 'ovr').fit(X_train, y_train)
    nn = MLPClassifier(hidden_layer_sizes=(100,), activation = 'tanh').fit(X_train, y_train)
    # The position in this list is the model code that is mapped to a model
    # name after the loop, so the two orders must stay in sync.
    for model_idx, est in enumerate([svm, lasso_cv, ridge_cv, rf, nn]):
        bill_pred = [bill_id2class[i] for i in est.predict(X_test)]
        # f1_score is taken from the public sklearn.metrics namespace instead
        # of the private metrics.classification module.
        store_est.append([d, model_idx,
                          metrics.f1_score(fmt_y, bill_pred, average = 'macro'),
                          metrics.f1_score(fmt_y, bill_pred, average = 'weighted')])
store_est = pd.DataFrame(store_est)
store_est.columns = ['data_type', 'model', 'f1_average', 'f1_weighted']
store_est['model'] = np.array(['SVM', 'LASSO', 'Ridge', 'Random Forest', 'Neural Net'])[store_est['model']]
store_est['data_type'] = np.array(['Technical', 'Motivations', 'All'])[store_est['data_type']]
store_est.to_csv('figures/cv_bills_alternative_models.csv', index = False)

#%%Create some final figures/tables.
# Reload every result table from disk, so this cell also works when the
# validation cells above were skipped but their CSV files exist.
all_f1s = pd.read_csv('figures/cv_svm.csv')
all_other_f1s = pd.read_csv('figures/cv_bills_alternative_models.csv')
store_est = pd.read_csv('figures/cv_speeches_alternative_models.csv')
store_annotation_f1s = pd.read_csv('figures/cv_speeches_svm.csv')

# Explicit copy: a column of cv_table is overwritten below, and assigning
# into a filtered slice of all_f1s triggers pandas' chained-assignment
# warning.
cv_table = all_f1s.loc[all_f1s['type'] == 'joint'].copy()

print(store_annotation_f1s[store_annotation_f1s['C'] == 1])
print(store_annotation_f1s.groupby(['data_type'], sort=False)['f1_weighted'].max())
print(store_annotation_f1s.groupby(['data_type'], sort=False)['f1_avg'].max())

# Rows of specification index 5, and within them the best macro-F1 row per
# text type.
all_f1s = all_f1s.loc[all_f1s['class'] == 5]
best_all_avg = all_f1s[all_f1s.groupby(['type'])['f1_score_average'].transform('max') == all_f1s['f1_score_average']]

# Number the specifications from 1 in the printed tables.
cv_table['class'] = cv_table['class'] + 1

def _cv_latex_rows(value_col):
    # Specification x C table of `value_col`, rounded to 3 decimals and
    # reduced to the body rows of the LaTeX tabular.
    tbl = pd.crosstab(cv_table['class'], cv_table['C'], cv_table[value_col], aggfunc='sum')
    tbl.loc[:,:] = np.round(np.array(tbl), 3)
    tbl = tbl.to_latex(header = True, index = True, float_format=None)
    # Inner sub: drop everything through the last 'midrule' (greedy match).
    # Outer sub: drop the trailing non-numeric lines ending in 'tabular}'.
    return regex.sub('\n[^0-9]+?tabular}', '', regex.sub('^[^~]+midrule', '', tbl))

tbl_cv = _cv_latex_rows('f1_score_average')
w_tbl_cv = _cv_latex_rows('f1_score_weighted_avg')

with io.open('figures/svm_cv_weighted.tex', 'w') as f: f.write(w_tbl_cv)
with io.open('figures/svm_cv_avg.tex', 'w') as f: f.write(tbl_cv)

#%%
def _alt_methods_table(score_col):
    # Side-by-side table of `score_col`: speech scores (rows with data_type
    # 'All' in store_est) next to the bill scores in all_other_f1s, joined on
    # row position.
    speech_scores = store_est.loc[store_est['data_type'] == 'All', ['model', score_col]]
    bill_scores = all_other_f1s.loc[:, score_col]
    table = pd.concat([speech_scores.reset_index(drop = True), bill_scores], axis = 1, ignore_index = True)
    # After ignore_index the columns are labelled 0, 1, 2; round the two
    # score columns.
    table.loc[:,1:] = np.round(np.array(table.loc[:,1:]), 3)
    table.columns = ['Method', 'Speeches', 'Bills']
    return table

tab_other_avg = _alt_methods_table('f1_average')
tab_other_wavg = _alt_methods_table('f1_weighted')

# Removes the tabular begin/end wrapper and the top/mid rules from the LaTeX
# output.
latex_wrapper = '\n[^0-9]+?tabular}|^.*tabular.*\n\\\\toprule\n?|\\\\midrule'

with io.open('figures/alt_methods_avg.tex', 'w') as f:
    f.write(regex.sub(latex_wrapper, '', tab_other_avg.to_latex(header = True, index = False)))
with io.open('figures/alt_methods_weight_avg.tex', 'w') as f:
    f.write(regex.sub(latex_wrapper, '', tab_other_wavg.to_latex(header = True, index = False)))

